{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| default_exp models.patchtst"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# PatchTST"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The PatchTST model is an efficient Transformer-based model for multivariate time series forecasting.\n",
    "\n",
    "It is based on two key components:\n",
    "- segmentation of time series into windows (patches) which serve as input tokens to the Transformer\n",
    "- channel-independence, where each channel contains a single univariate time series."
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**References**<br>\n",
    "- [Nie, Y., Nguyen, N. H., Sinthong, P., & Kalagnanam, J. (2022). \"A Time Series is Worth 64 Words: Long-term Forecasting with Transformers\"](https://arxiv.org/pdf/2211.14730.pdf)<br>"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "![Figure 1. PatchTST.](imgs_models/patchtst.png)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "import math\n",
    "import numpy as np\n",
    "from typing import Optional #, Any, Tuple\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "\n",
    "from neuralforecast.common._base_windows import BaseWindows\n",
    "\n",
    "from neuralforecast.losses.pytorch import MAE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "from fastcore.test import test_eq\n",
    "from nbdev.showdoc import show_doc"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. Backbone"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Auxiliary Functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "class Transpose(nn.Module):\n",
    "    \"\"\"Wrap `Tensor.transpose` as a module so it can be placed inside `nn.Sequential`.\n",
    "\n",
    "    **Parameters:**<br>\n",
    "    `*dims`: dimensions to swap, forwarded unchanged to `Tensor.transpose`.<br>\n",
    "    `contiguous`: bool=False, if True the transposed tensor is returned via `.contiguous()`.<br>\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, *dims, contiguous=False):\n",
    "        super().__init__()\n",
    "        self.dims = dims\n",
    "        self.contiguous = contiguous\n",
    "\n",
    "    def forward(self, x):\n",
    "        swapped = x.transpose(*self.dims)\n",
    "        return swapped.contiguous() if self.contiguous else swapped\n",
    "\n",
    "def get_activation_fn(activation):\n",
    "    \"\"\"Return an activation module built from a name or a factory.\n",
    "\n",
    "    **Parameters:**<br>\n",
    "    `activation`: either a callable, which is called with no arguments and whose result is returned,\n",
    "    or one of the case-insensitive strings \"relu\" / \"gelu\".<br>\n",
    "\n",
    "    **Raises:**<br>\n",
    "    `ValueError`: if `activation` is a string other than \"relu\" or \"gelu\".<br>\n",
    "    \"\"\"\n",
    "    if callable(activation):\n",
    "        return activation()\n",
    "    name = activation.lower()\n",
    "    if name == \"relu\":\n",
    "        return nn.ReLU()\n",
    "    if name == \"gelu\":\n",
    "        return nn.GELU()\n",
    "    raise ValueError(f'{activation} is not available. You can use \"relu\", \"gelu\", or a callable')"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Positional Encoding"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "def PositionalEncoding(q_len, hidden_size, normalize=True):\n",
    "    \"\"\"Sine/cosine positional encoding.\n",
    "\n",
    "    Even columns hold sines and odd columns hold cosines of `position * div_term`.\n",
    "\n",
    "    **Parameters:**<br>\n",
    "    `q_len`: int, number of positions (rows).<br>\n",
    "    `hidden_size`: int, encoding width (columns); both even and odd sizes are supported.<br>\n",
    "    `normalize`: bool=True, if True the table is shifted to zero mean and divided by 10 times its std.<br>\n",
    "\n",
    "    **Returns:**<br>\n",
    "    `pe`: tensor of shape [q_len x hidden_size].<br>\n",
    "    \"\"\"\n",
    "    pe = torch.zeros(q_len, hidden_size)\n",
    "    position = torch.arange(0, q_len).unsqueeze(1)\n",
    "    div_term = torch.exp(torch.arange(0, hidden_size, 2) * -(math.log(10000.0) / hidden_size))\n",
    "    angles = position * div_term                      # [q_len x ceil(hidden_size / 2)]\n",
    "    pe[:, 0::2] = torch.sin(angles)\n",
    "    # With an odd hidden_size there is one fewer cosine column than sine columns,\n",
    "    # so the angles are trimmed to the number of odd-indexed columns.\n",
    "    pe[:, 1::2] = torch.cos(angles[:, :hidden_size // 2])\n",
    "    if normalize:\n",
    "        pe = pe - pe.mean()\n",
    "        pe = pe / (pe.std() * 10)\n",
    "    return pe\n",
    "\n",
    "# Alias kept for callers that refer to the encoding by its sin/cos name.\n",
    "SinCosPosEncoding = PositionalEncoding\n",
    "\n",
    "def Coord2dPosEncoding(q_len, hidden_size, exponential=False, normalize=True, eps=1e-3):\n",
    "    \"\"\"2d coordinate positional encoding.\n",
    "\n",
    "    Builds `2 * row**x * col**x - 1` from two [0, 1] ramps and nudges the exponent `x`\n",
    "    in steps of 0.001 (at most 100 attempts) until the mean of the table is within `eps` of zero.\n",
    "\n",
    "    **Parameters:**<br>\n",
    "    `q_len`: int, number of positions (rows).<br>\n",
    "    `hidden_size`: int, encoding width (columns).<br>\n",
    "    `exponential`: bool=False, start the exponent at 0.5 instead of 1.<br>\n",
    "    `normalize`: bool=True, if True the table is shifted to zero mean and divided by 10 times its std.<br>\n",
    "    `eps`: float=1e-3, tolerance on the absolute mean that stops the exponent search.<br>\n",
    "\n",
    "    **Returns:**<br>\n",
    "    `cpe`: tensor of shape [q_len x hidden_size].<br>\n",
    "    \"\"\"\n",
    "    x = .5 if exponential else 1\n",
    "    # The two ramps do not depend on the exponent, so they are built once.\n",
    "    rows = torch.linspace(0, 1, q_len).reshape(-1, 1)\n",
    "    cols = torch.linspace(0, 1, hidden_size).reshape(1, -1)\n",
    "    for _ in range(100):\n",
    "        cpe = 2 * (rows ** x) * (cols ** x) - 1\n",
    "        mean = cpe.mean()\n",
    "        if abs(mean) <= eps:\n",
    "            break\n",
    "        elif mean > eps:\n",
    "            x += .001\n",
    "        else:\n",
    "            x -= .001\n",
    "    if normalize:\n",
    "        cpe = cpe - cpe.mean()\n",
    "        cpe = cpe / (cpe.std() * 10)\n",
    "    return cpe\n",
    "\n",
    "def Coord1dPosEncoding(q_len, exponential=False, normalize=True):\n",
    "    \"\"\"1d coordinate positional encoding.\n",
    "\n",
    "    Maps a [0, 1] ramp of `q_len` points (optionally square-rooted) onto [-1, 1].\n",
    "\n",
    "    **Parameters:**<br>\n",
    "    `q_len`: int, number of positions.<br>\n",
    "    `exponential`: bool=False, raise the ramp to the power 0.5 instead of 1.<br>\n",
    "    `normalize`: bool=True, if True the column is shifted to zero mean and divided by 10 times its std.<br>\n",
    "\n",
    "    **Returns:**<br>\n",
    "    `cpe`: tensor of shape [q_len x 1].<br>\n",
    "    \"\"\"\n",
    "    power = .5 if exponential else 1\n",
    "    ramp = torch.linspace(0, 1, q_len).reshape(-1, 1)\n",
    "    cpe = 2 * (ramp ** power) - 1\n",
    "    if normalize:\n",
    "        centered = cpe - cpe.mean()\n",
    "        cpe = centered / (centered.std() * 10)\n",
    "    return cpe\n",
    "\n",
    "def positional_encoding(pe, learn_pe, q_len, hidden_size):\n",
    "    \"\"\"Create the positional-encoding parameter for a sequence of `q_len` tokens.\n",
    "\n",
    "    **Parameters:**<br>\n",
    "    `pe`: None or str, encoding type:<br>\n",
    "    - None: [q_len x hidden_size], uniform in [-0.02, 0.02], always frozen (`learn_pe` is overridden to False).<br>\n",
    "    - 'zero': [q_len x 1], uniform in [-0.02, 0.02].<br>\n",
    "    - 'zeros': [q_len x hidden_size], uniform in [-0.02, 0.02].<br>\n",
    "    - 'normal' / 'gauss': [q_len x 1], normal with mean 0 and std 0.1.<br>\n",
    "    - 'uniform': [q_len x 1], uniform in [0, 0.1].<br>\n",
    "    - 'lin1d' / 'exp1d': `Coord1dPosEncoding`.<br>\n",
    "    - 'lin2d' / 'exp2d': `Coord2dPosEncoding`.<br>\n",
    "    - 'sincos': `PositionalEncoding`.<br>\n",
    "    `learn_pe`: bool, whether the returned parameter requires gradients.<br>\n",
    "    `q_len`: int, number of positions.<br>\n",
    "    `hidden_size`: int, encoding width.<br>\n",
    "\n",
    "    **Returns:**<br>\n",
    "    `nn.Parameter` wrapping the encoding, with `requires_grad=learn_pe`.<br>\n",
    "\n",
    "    **Raises:**<br>\n",
    "    `ValueError`: if `pe` is not one of the supported types.<br>\n",
    "    \"\"\"\n",
    "    if pe is None:\n",
    "        # pe = None and learn_pe = False can be used to measure impact of pe\n",
    "        W_pos = torch.empty((q_len, hidden_size))\n",
    "        nn.init.uniform_(W_pos, -0.02, 0.02)\n",
    "        learn_pe = False\n",
    "    elif pe == 'zero':\n",
    "        W_pos = torch.empty((q_len, 1))\n",
    "        nn.init.uniform_(W_pos, -0.02, 0.02)\n",
    "    elif pe == 'zeros':\n",
    "        W_pos = torch.empty((q_len, hidden_size))\n",
    "        nn.init.uniform_(W_pos, -0.02, 0.02)\n",
    "    elif pe == 'normal' or pe == 'gauss':\n",
    "        W_pos = torch.zeros((q_len, 1))\n",
    "        torch.nn.init.normal_(W_pos, mean=0.0, std=0.1)\n",
    "    elif pe == 'uniform':\n",
    "        W_pos = torch.zeros((q_len, 1))\n",
    "        nn.init.uniform_(W_pos, a=0.0, b=0.1)\n",
    "    elif pe == 'lin1d':\n",
    "        W_pos = Coord1dPosEncoding(q_len, exponential=False, normalize=True)\n",
    "    elif pe == 'exp1d':\n",
    "        W_pos = Coord1dPosEncoding(q_len, exponential=True, normalize=True)\n",
    "    elif pe == 'lin2d':\n",
    "        W_pos = Coord2dPosEncoding(q_len, hidden_size, exponential=False, normalize=True)\n",
    "    elif pe == 'exp2d':\n",
    "        W_pos = Coord2dPosEncoding(q_len, hidden_size, exponential=True, normalize=True)\n",
    "    elif pe == 'sincos':\n",
    "        W_pos = PositionalEncoding(q_len, hidden_size, normalize=True)\n",
    "    else:\n",
    "        # Adjacent literals keep the message on one line without embedding source indentation.\n",
    "        raise ValueError(\n",
    "            f\"{pe} is not a valid pe (positional encoder). Available types: 'gauss'=='normal', \"\n",
    "            \"'zeros', 'zero', 'uniform', 'lin1d', 'exp1d', 'lin2d', 'exp2d', 'sincos', None.\"\n",
    "        )\n",
    "    return nn.Parameter(W_pos, requires_grad=learn_pe)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### RevIN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "class RevIN(nn.Module):\n",
    "    \"\"\"Reversible instance normalization.\n",
    "\n",
    "    Called with `mode='norm'`, the layer computes statistics of `x`, stores them on the module and\n",
    "    returns the normalized input. Called with `mode='denorm'`, it reverses that transformation\n",
    "    using the statistics stored by the most recent 'norm' call.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, num_features: int, eps=1e-5, affine=True, subtract_last=False):\n",
    "        \"\"\"\n",
    "        :param num_features: the number of features or channels\n",
    "        :param eps: a value added for numerical stability\n",
    "        :param affine: if True, RevIN has learnable affine parameters\n",
    "        :param subtract_last: if True, center on the last step along dim 1 instead of the mean\n",
    "        \"\"\"\n",
    "        super().__init__()\n",
    "        self.num_features = num_features\n",
    "        self.eps = eps\n",
    "        self.affine = affine\n",
    "        self.subtract_last = subtract_last\n",
    "        if self.affine:\n",
    "            self._init_params()\n",
    "\n",
    "    def forward(self, x, mode:str):\n",
    "        \"\"\"Normalize (`mode='norm'`) or de-normalize (`mode='denorm'`) `x`; any other mode raises NotImplementedError.\"\"\"\n",
    "        if mode == 'norm':\n",
    "            self._get_statistics(x)\n",
    "            return self._normalize(x)\n",
    "        if mode == 'denorm':\n",
    "            return self._denormalize(x)\n",
    "        raise NotImplementedError\n",
    "\n",
    "    def _init_params(self):\n",
    "        \"\"\"Create the learnable scale (ones) and shift (zeros), one value per feature.\"\"\"\n",
    "        self.affine_weight = nn.Parameter(torch.ones(self.num_features))\n",
    "        self.affine_bias = nn.Parameter(torch.zeros(self.num_features))\n",
    "\n",
    "    def _get_statistics(self, x):\n",
    "        \"\"\"Store the centering term (`last` or `mean`) and `stdev` computed from `x`.\"\"\"\n",
    "        # Reduce over every dimension except the first and the last one.\n",
    "        reduce_dims = tuple(range(1, x.ndim - 1))\n",
    "        if self.subtract_last:\n",
    "            self.last = x[:, -1, :].unsqueeze(1)\n",
    "        else:\n",
    "            self.mean = x.mean(dim=reduce_dims, keepdim=True).detach()\n",
    "        variance = x.var(dim=reduce_dims, keepdim=True, unbiased=False)\n",
    "        self.stdev = (variance + self.eps).sqrt().detach()\n",
    "\n",
    "    def _normalize(self, x):\n",
    "        \"\"\"Center and scale `x` with the stored statistics, then apply the affine map if enabled.\"\"\"\n",
    "        center = self.last if self.subtract_last else self.mean\n",
    "        x = (x - center) / self.stdev\n",
    "        if self.affine:\n",
    "            x = x * self.affine_weight + self.affine_bias\n",
    "        return x\n",
    "\n",
    "    def _denormalize(self, x):\n",
    "        \"\"\"Undo the affine map (if enabled), then rescale and re-center with the stored statistics.\"\"\"\n",
    "        if self.affine:\n",
    "            # eps**2 keeps the division defined when a learned weight is exactly zero.\n",
    "            x = (x - self.affine_bias) / (self.affine_weight + self.eps * self.eps)\n",
    "        x = x * self.stdev\n",
    "        center = self.last if self.subtract_last else self.mean\n",
    "        return x + center"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Encoder"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "class PatchTST_backbone(nn.Module):\n",
    "    \"\"\"PatchTST backbone: optional RevIN, patching, channel-independent encoder and output head.\n",
    "\n",
    "    **Parameters (selection):**<br>\n",
    "    `c_in`: int, number of input channels (series).<br>\n",
    "    `c_out`: int, multiplier of `h` for the width of the flatten head output.<br>\n",
    "    `input_size`: int, length of the input window.<br>\n",
    "    `h`: int, forecast horizon.<br>\n",
    "    `patch_len`, `stride`: int, length of each patch and step between consecutive patches.<br>\n",
    "    `norm`: str='BatchNorm', normalization used by the encoder layers; forwarded to `TSTiEncoder`.<br>\n",
    "    `padding_patch`: None or 'end'; with 'end' the series is right-padded by `stride` replicated values,\n",
    "    which adds one patch.<br>\n",
    "    `pretrain_head`: bool=False, use the dropout + 1x1 convolution head from `create_pretrain_head`.<br>\n",
    "    `head_type`: str='flatten', head used when `pretrain_head` is False; only 'flatten' is implemented.<br>\n",
    "    `revin`, `affine`, `subtract_last`: RevIN switch and its options.<br>\n",
    "    The remaining arguments are passed through to `TSTiEncoder`.<br>\n",
    "\n",
    "    **Raises:**<br>\n",
    "    `ValueError`: if `pretrain_head` is False and `head_type` is not 'flatten'.<br>\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, c_in:int, c_out:int, input_size:int, h:int, patch_len:int, stride:int, max_seq_len:Optional[int]=1024, \n",
    "                 n_layers:int=3, hidden_size=128, n_heads=16, d_k:Optional[int]=None, d_v:Optional[int]=None,\n",
    "                 linear_hidden_size:int=256, norm:str='BatchNorm', attn_dropout:float=0., dropout:float=0., act:str=\"gelu\", key_padding_mask:str='auto',\n",
    "                 padding_var:Optional[int]=None, attn_mask:Optional[torch.Tensor]=None, res_attention:bool=True, pre_norm:bool=False, store_attn:bool=False,\n",
    "                 pe:str='zeros', learn_pe:bool=True, fc_dropout:float=0., head_dropout = 0, padding_patch = None,\n",
    "                 pretrain_head:bool=False, head_type = 'flatten', individual = False, revin = True, affine = True, subtract_last = False):\n",
    "        \n",
    "        super().__init__()\n",
    "\n",
    "        # RevIn\n",
    "        self.revin = revin\n",
    "        if self.revin: self.revin_layer = RevIN(c_in, affine=affine, subtract_last=subtract_last)\n",
    "\n",
    "        # Patching\n",
    "        self.patch_len = patch_len\n",
    "        self.stride = stride\n",
    "        self.padding_patch = padding_patch\n",
    "        patch_num = int((input_size - patch_len)/stride + 1)\n",
    "        if padding_patch == 'end': # can be modified to general case\n",
    "            self.padding_patch_layer = nn.ReplicationPad1d((0, stride)) \n",
    "            patch_num += 1\n",
    "\n",
    "        # Backbone\n",
    "        # `norm` is forwarded explicitly; otherwise the encoder falls back to its own default\n",
    "        # and the value requested by the caller is silently ignored.\n",
    "        self.backbone = TSTiEncoder(c_in, patch_num=patch_num, patch_len=patch_len, max_seq_len=max_seq_len,\n",
    "                                n_layers=n_layers, hidden_size=hidden_size, n_heads=n_heads, d_k=d_k, d_v=d_v, linear_hidden_size=linear_hidden_size,\n",
    "                                norm=norm, attn_dropout=attn_dropout, dropout=dropout, act=act, key_padding_mask=key_padding_mask, padding_var=padding_var,\n",
    "                                attn_mask=attn_mask, res_attention=res_attention, pre_norm=pre_norm, store_attn=store_attn,\n",
    "                                pe=pe, learn_pe=learn_pe)\n",
    "\n",
    "        # Head\n",
    "        self.head_nf = hidden_size * patch_num\n",
    "        self.n_vars = c_in\n",
    "        self.c_out = c_out\n",
    "        self.pretrain_head = pretrain_head\n",
    "        self.head_type = head_type\n",
    "        self.individual = individual\n",
    "\n",
    "        if self.pretrain_head: \n",
    "            self.head = self.create_pretrain_head(self.head_nf, c_in, fc_dropout) # custom head passed as a partial func with all its kwargs\n",
    "        elif head_type == 'flatten': \n",
    "            self.head = Flatten_Head(self.individual, self.n_vars, self.head_nf, h, c_out, head_dropout=head_dropout)\n",
    "        else:\n",
    "            # Fail at construction time instead of leaving `self.head` undefined until `forward`.\n",
    "            raise ValueError(f\"head_type={head_type!r} is not supported. Use 'flatten' or set pretrain_head=True.\")\n",
    "\n",
    "    def forward(self, z):                                                                   # z: [bs x nvars x seq_len]\n",
    "        # norm (RevIN is applied with the channel dimension last, hence the permutes)\n",
    "        if self.revin: \n",
    "            z = z.permute(0,2,1)\n",
    "            z = self.revin_layer(z, 'norm')\n",
    "            z = z.permute(0,2,1)\n",
    "\n",
    "        # do patching\n",
    "        if self.padding_patch == 'end':\n",
    "            z = self.padding_patch_layer(z)\n",
    "        z = z.unfold(dimension=-1, size=self.patch_len, step=self.stride)                   # z: [bs x nvars x patch_num x patch_len]\n",
    "        z = z.permute(0,1,3,2)                                                              # z: [bs x nvars x patch_len x patch_num]\n",
    "\n",
    "        # model\n",
    "        z = self.backbone(z)                                                                # z: [bs x nvars x hidden_size x patch_num]\n",
    "        z = self.head(z)                                                                    # z: [bs x nvars x h] \n",
    "\n",
    "        # denorm (reuses the statistics stored by the 'norm' call above)\n",
    "        if self.revin:\n",
    "            z = z.permute(0,2,1)\n",
    "            z = self.revin_layer(z, 'denorm')\n",
    "            z = z.permute(0,2,1)\n",
    "        return z\n",
    "    \n",
    "    def create_pretrain_head(self, head_nf, vars, dropout):\n",
    "        \"\"\"Return the pretraining head: dropout followed by a kernel-size-1 `Conv1d` from `head_nf` to `vars` channels.\"\"\"\n",
    "        return nn.Sequential(nn.Dropout(dropout),\n",
    "                    nn.Conv1d(head_nf, vars, 1)\n",
    "                    )\n",
    "\n",
    "\n",
    "class Flatten_Head(nn.Module):\n",
    "    \"\"\"Head that flattens the last two dimensions and projects them with a linear layer.\n",
    "\n",
    "    **Parameters:**<br>\n",
    "    `individual`: bool, if True every variable gets its own flatten / linear / dropout stack,\n",
    "    otherwise a single stack is shared by all variables.<br>\n",
    "    `n_vars`: int, number of variables.<br>\n",
    "    `nf`: int, size of the flattened input (hidden_size * patch_num).<br>\n",
    "    `h`: int, forecast horizon.<br>\n",
    "    `c_out`: int, the linear layer outputs `h * c_out` features.<br>\n",
    "    `head_dropout`: float=0, dropout applied after the linear layer.<br>\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, individual, n_vars, nf, h, c_out, head_dropout=0):\n",
    "        super().__init__()\n",
    "\n",
    "        self.individual = individual\n",
    "        self.n_vars = n_vars\n",
    "        self.c_out = c_out\n",
    "\n",
    "        out_features = h * c_out\n",
    "        if self.individual:\n",
    "            # One stack per variable; attributes are registered as linears, dropouts, flattens.\n",
    "            self.linears = nn.ModuleList([nn.Linear(nf, out_features) for _ in range(self.n_vars)])\n",
    "            self.dropouts = nn.ModuleList([nn.Dropout(head_dropout) for _ in range(self.n_vars)])\n",
    "            self.flattens = nn.ModuleList([nn.Flatten(start_dim=-2) for _ in range(self.n_vars)])\n",
    "        else:\n",
    "            self.flatten = nn.Flatten(start_dim=-2)\n",
    "            self.linear = nn.Linear(nf, out_features)\n",
    "            self.dropout = nn.Dropout(head_dropout)\n",
    "\n",
    "    def forward(self, x):                                 # x: [bs x nvars x hidden_size x patch_num]\n",
    "        if not self.individual:\n",
    "            return self.dropout(self.linear(self.flatten(x)))\n",
    "\n",
    "        # Each variable goes through its own stack: flatten -> linear -> dropout.\n",
    "        per_var = [\n",
    "            self.dropouts[i](self.linears[i](self.flattens[i](x[:, i, :, :])))\n",
    "            for i in range(self.n_vars)\n",
    "        ]\n",
    "        return torch.stack(per_var, dim=1)                # [bs x nvars x h]\n",
    "\n",
    "\n",
    "class TSTiEncoder(nn.Module):  #i means channel-independent\n",
    "    \"\"\"Channel-independent encoder: every variable's patch sequence is encoded as a separate sample.\n",
    "\n",
    "    Patches are projected to `hidden_size`, a positional encoding is added, and the result is\n",
    "    passed through a `TSTEncoder`. `c_in`, `max_seq_len`, `key_padding_mask`, `padding_var` and\n",
    "    `attn_mask` are accepted for signature compatibility but are not used by this class.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, c_in, patch_num, patch_len, max_seq_len=1024,\n",
    "                 n_layers=3, hidden_size=128, n_heads=16, d_k=None, d_v=None,\n",
    "                 linear_hidden_size=256, norm='BatchNorm', attn_dropout=0., dropout=0., act=\"gelu\", store_attn=False,\n",
    "                 key_padding_mask='auto', padding_var=None, attn_mask=None, res_attention=True, pre_norm=False,\n",
    "                 pe='zeros', learn_pe=True):\n",
    "        super().__init__()\n",
    "\n",
    "        self.patch_num = patch_num\n",
    "        self.patch_len = patch_len\n",
    "\n",
    "        # Input encoding: the sequence length seen by the encoder is the number of patches.\n",
    "        self.W_P = nn.Linear(patch_len, hidden_size)        # Eq 1: projection of feature vectors onto a d-dim vector space\n",
    "        self.seq_len = patch_num\n",
    "\n",
    "        # Positional encoding\n",
    "        self.W_pos = positional_encoding(pe, learn_pe, patch_num, hidden_size)\n",
    "\n",
    "        # Residual dropout\n",
    "        self.dropout = nn.Dropout(dropout)\n",
    "\n",
    "        # Encoder\n",
    "        self.encoder = TSTEncoder(patch_num, hidden_size, n_heads, d_k=d_k, d_v=d_v,\n",
    "                                  linear_hidden_size=linear_hidden_size, norm=norm,\n",
    "                                  attn_dropout=attn_dropout, dropout=dropout, pre_norm=pre_norm,\n",
    "                                  activation=act, res_attention=res_attention, n_layers=n_layers,\n",
    "                                  store_attn=store_attn)\n",
    "\n",
    "    def forward(self, x) -> torch.Tensor:                                        # x: [bs x nvars x patch_len x patch_num]\n",
    "        bs, n_vars = x.shape[0], x.shape[1]\n",
    "\n",
    "        # Input encoding\n",
    "        tokens = self.W_P(x.permute(0,1,3,2))                                    # [bs x nvars x patch_num x hidden_size]\n",
    "\n",
    "        # Fold the variables into the batch dimension so each one is encoded independently.\n",
    "        tokens = tokens.reshape(bs * n_vars, tokens.shape[2], tokens.shape[3])   # [bs * nvars x patch_num x hidden_size]\n",
    "        tokens = self.dropout(tokens + self.W_pos)                               # [bs * nvars x patch_num x hidden_size]\n",
    "\n",
    "        # Encoder\n",
    "        encoded = self.encoder(tokens)                                           # [bs * nvars x patch_num x hidden_size]\n",
    "        encoded = encoded.reshape(-1, n_vars, encoded.shape[-2], encoded.shape[-1])  # [bs x nvars x patch_num x hidden_size]\n",
    "        return encoded.permute(0,1,3,2)                                          # [bs x nvars x hidden_size x patch_num]\n",
    "            \n",
    "\n",
    "class TSTEncoder(nn.Module):\n",
    "    \"\"\"Stack of `n_layers` `TSTEncoderLayer` modules applied one after another.\n",
    "\n",
    "    With `res_attention=True` every layer returns `(output, scores)` and the scores are handed\n",
    "    to the next layer as `prev`; otherwise each layer returns only its output.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, q_len, hidden_size, n_heads, d_k=None, d_v=None, linear_hidden_size=None, \n",
    "                        norm='BatchNorm', attn_dropout=0., dropout=0., activation='gelu',\n",
    "                        res_attention=False, n_layers=1, pre_norm=False, store_attn=False):\n",
    "        super().__init__()\n",
    "\n",
    "        stack = []\n",
    "        for _ in range(n_layers):\n",
    "            stack.append(TSTEncoderLayer(q_len, hidden_size, n_heads=n_heads, d_k=d_k, d_v=d_v,\n",
    "                                         linear_hidden_size=linear_hidden_size, norm=norm,\n",
    "                                         attn_dropout=attn_dropout, dropout=dropout,\n",
    "                                         activation=activation, res_attention=res_attention,\n",
    "                                         pre_norm=pre_norm, store_attn=store_attn))\n",
    "        self.layers = nn.ModuleList(stack)\n",
    "        self.res_attention = res_attention\n",
    "\n",
    "    def forward(self, src:torch.Tensor, key_padding_mask:Optional[torch.Tensor]=None, attn_mask:Optional[torch.Tensor]=None):\n",
    "        hidden = src\n",
    "        prev_scores = None\n",
    "        for layer in self.layers:\n",
    "            if self.res_attention:\n",
    "                hidden, prev_scores = layer(hidden, prev=prev_scores, key_padding_mask=key_padding_mask, attn_mask=attn_mask)\n",
    "            else:\n",
    "                hidden = layer(hidden, key_padding_mask=key_padding_mask, attn_mask=attn_mask)\n",
    "        return hidden\n",
    "\n",
    "\n",
    "class TSTEncoderLayer(nn.Module):\n",
    "    \"\"\"Single encoder layer: multi-head self-attention and a position-wise feed-forward network,\n",
    "    each followed by a residual connection and a normalization (BatchNorm or LayerNorm).\n",
    "\n",
    "    With `pre_norm=True` the normalization is applied before each sub-layer instead of after it.\n",
    "    With `res_attention=True` `forward` returns `(output, scores)`, otherwise only `output`.\n",
    "    With `store_attn=True` the attention weights of the last call are kept in `self.attn`.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, q_len, hidden_size, n_heads, d_k=None, d_v=None, linear_hidden_size=256, store_attn=False,\n",
    "                 norm='BatchNorm', attn_dropout=0, dropout=0., bias=True, activation=\"gelu\", res_attention=False, pre_norm=False):\n",
    "        super().__init__()\n",
    "        assert not hidden_size%n_heads, f\"hidden_size ({hidden_size}) must be divisible by n_heads ({n_heads})\"\n",
    "        head_dim = hidden_size // n_heads\n",
    "        if d_k is None:\n",
    "            d_k = head_dim\n",
    "        if d_v is None:\n",
    "            d_v = head_dim\n",
    "\n",
    "        def make_norm():\n",
    "            # BatchNorm1d normalizes dim 1, so hidden_size is moved there and back.\n",
    "            if \"batch\" in norm.lower():\n",
    "                return nn.Sequential(Transpose(1,2), nn.BatchNorm1d(hidden_size), Transpose(1,2))\n",
    "            return nn.LayerNorm(hidden_size)\n",
    "\n",
    "        # Multi-Head attention\n",
    "        self.res_attention = res_attention\n",
    "        self.self_attn = _MultiheadAttention(hidden_size, n_heads, d_k, d_v, attn_dropout=attn_dropout,\n",
    "                                             proj_dropout=dropout, res_attention=res_attention)\n",
    "\n",
    "        # Add & Norm (attention sub-layer)\n",
    "        self.dropout_attn = nn.Dropout(dropout)\n",
    "        self.norm_attn = make_norm()\n",
    "\n",
    "        # Position-wise Feed-Forward\n",
    "        self.ff = nn.Sequential(\n",
    "            nn.Linear(hidden_size, linear_hidden_size, bias=bias),\n",
    "            get_activation_fn(activation),\n",
    "            nn.Dropout(dropout),\n",
    "            nn.Linear(linear_hidden_size, hidden_size, bias=bias),\n",
    "        )\n",
    "\n",
    "        # Add & Norm (feed-forward sub-layer)\n",
    "        self.dropout_ffn = nn.Dropout(dropout)\n",
    "        self.norm_ffn = make_norm()\n",
    "\n",
    "        self.pre_norm = pre_norm\n",
    "        self.store_attn = store_attn\n",
    "\n",
    "    def forward(self, src:torch.Tensor, prev:Optional[torch.Tensor]=None,\n",
    "                key_padding_mask:Optional[torch.Tensor]=None,\n",
    "                attn_mask:Optional[torch.Tensor]=None): # -> Tuple[torch.Tensor, Any]:\n",
    "\n",
    "        # --- Multi-Head attention sublayer ---\n",
    "        if self.pre_norm:\n",
    "            src = self.norm_attn(src)\n",
    "        if self.res_attention:\n",
    "            attn_out, attn, scores = self.self_attn(src, src, src, prev,\n",
    "                                                    key_padding_mask=key_padding_mask, attn_mask=attn_mask)\n",
    "        else:\n",
    "            attn_out, attn = self.self_attn(src, src, src, key_padding_mask=key_padding_mask, attn_mask=attn_mask)\n",
    "        if self.store_attn:\n",
    "            self.attn = attn\n",
    "        # Add: residual connection with residual dropout\n",
    "        src = src + self.dropout_attn(attn_out)\n",
    "\n",
    "        # Post-norm closes the attention sub-layer; pre-norm opens the feed-forward sub-layer.\n",
    "        if self.pre_norm:\n",
    "            src = self.norm_ffn(src)\n",
    "        else:\n",
    "            src = self.norm_attn(src)\n",
    "\n",
    "        # --- Feed-forward sublayer ---\n",
    "        ff_out = self.ff(src)\n",
    "        # Add: residual connection with residual dropout\n",
    "        src = src + self.dropout_ffn(ff_out)\n",
    "        if not self.pre_norm:\n",
    "            src = self.norm_ffn(src)\n",
    "\n",
    "        return (src, scores) if self.res_attention else src\n",
    "\n",
    "\n",
    "class _MultiheadAttention(nn.Module):\n",
    "    def __init__(self, hidden_size, n_heads, d_k=None, d_v=None,\n",
    "                 res_attention=False, attn_dropout=0., proj_dropout=0., qkv_bias=True, lsa=False):\n",
    "        \"\"\"\n",
    "        Multi-head attention layer: per-head Q/K/V projections, scaled dot-product attention,\n",
    "        and an output projection followed by dropout.\n",
    "\n",
    "        `d_k` and `d_v` default to `hidden_size // n_heads` when not given.\n",
    "\n",
    "        Input shape:\n",
    "            Q:       [batch_size (bs) x max_q_len x hidden_size]\n",
    "            K, V:    [batch_size (bs) x q_len x hidden_size]\n",
    "            mask:    [q_len x q_len]\n",
    "        \"\"\"\n",
    "        super().__init__()\n",
    "        head_dim = hidden_size // n_heads\n",
    "        self.n_heads = n_heads\n",
    "        self.d_k = head_dim if d_k is None else d_k\n",
    "        self.d_v = head_dim if d_v is None else d_v\n",
    "\n",
    "        # Query / key / value projections for all heads at once\n",
    "        self.W_Q = nn.Linear(hidden_size, self.d_k * n_heads, bias=qkv_bias)\n",
    "        self.W_K = nn.Linear(hidden_size, self.d_k * n_heads, bias=qkv_bias)\n",
    "        self.W_V = nn.Linear(hidden_size, self.d_v * n_heads, bias=qkv_bias)\n",
    "\n",
    "        # Scaled Dot-Product Attention (multiple heads)\n",
    "        self.res_attention = res_attention\n",
    "        self.sdp_attn = _ScaledDotProductAttention(hidden_size, n_heads, attn_dropout=attn_dropout,\n",
    "                                                   res_attention=self.res_attention, lsa=lsa)\n",
    "\n",
    "        # Project the concatenated heads back to hidden_size\n",
    "        self.to_out = nn.Sequential(nn.Linear(n_heads * self.d_v, hidden_size), nn.Dropout(proj_dropout))\n",
    "\n",
    "    def forward(self, Q:torch.Tensor, K:Optional[torch.Tensor]=None, V:Optional[torch.Tensor]=None, prev:Optional[torch.Tensor]=None,\n",
    "                key_padding_mask:Optional[torch.Tensor]=None, attn_mask:Optional[torch.Tensor]=None):\n",
    "\n",
    "        bs = Q.size(0)\n",
    "        # Self-attention when keys / values are omitted\n",
    "        K = Q if K is None else K\n",
    "        V = Q if V is None else V\n",
    "\n",
    "        # Linear (+ split in multiple heads)\n",
    "        q_s = self.W_Q(Q).view(bs, -1, self.n_heads, self.d_k).transpose(1,2)       # q_s    : [bs x n_heads x max_q_len x d_k]\n",
    "        k_s = self.W_K(K).view(bs, -1, self.n_heads, self.d_k).permute(0,2,3,1)     # k_s    : [bs x n_heads x d_k x q_len]\n",
    "        v_s = self.W_V(V).view(bs, -1, self.n_heads, self.d_v).transpose(1,2)       # v_s    : [bs x n_heads x q_len x d_v]\n",
    "\n",
    "        # Apply Scaled Dot-Product Attention (multiple heads)\n",
    "        # heads: [bs x n_heads x q_len x d_v], attn_weights: [bs x n_heads x q_len x q_len]\n",
    "        if self.res_attention:\n",
    "            heads, attn_weights, attn_scores = self.sdp_attn(q_s, k_s, v_s, prev=prev,\n",
    "                                                             key_padding_mask=key_padding_mask, attn_mask=attn_mask)\n",
    "        else:\n",
    "            heads, attn_weights = self.sdp_attn(q_s, k_s, v_s, key_padding_mask=key_padding_mask, attn_mask=attn_mask)\n",
    "\n",
    "        # Concatenate the heads and project back to the input width\n",
    "        merged = heads.transpose(1, 2).contiguous().view(bs, -1, self.n_heads * self.d_v)  # [bs x q_len x n_heads * d_v]\n",
    "        output = self.to_out(merged)\n",
    "\n",
    "        if self.res_attention:\n",
    "            return output, attn_weights, attn_scores\n",
    "        return output, attn_weights\n",
    "\n",
    "\n",
    "class _ScaledDotProductAttention(nn.Module):\n",
    "    \"\"\"\n",
    "    Scaled dot-product attention (Attention is all you need, Vaswani et al., 2017) with optional\n",
    "    residual attention from the previous layer (Realformer: Transformer likes residual attention,\n",
    "    He et al., 2020) and locality self attention (Vision Transformer for Small-Size Datasets,\n",
    "    Lee et al., 2021), in which the scale becomes a learnable parameter.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, hidden_size, n_heads, attn_dropout=0., res_attention=False, lsa=False):\n",
    "        super().__init__()\n",
    "        self.attn_dropout = nn.Dropout(attn_dropout)\n",
    "        self.res_attention = res_attention\n",
    "        # 1 / sqrt(head_dim); only trainable when lsa is enabled.\n",
    "        self.scale = nn.Parameter(torch.tensor((hidden_size // n_heads) ** -0.5), requires_grad=lsa)\n",
    "        self.lsa = lsa\n",
    "\n",
    "    def forward(self, q:torch.Tensor, k:torch.Tensor, v:torch.Tensor,\n",
    "                prev:Optional[torch.Tensor]=None, key_padding_mask:Optional[torch.Tensor]=None,\n",
    "                attn_mask:Optional[torch.Tensor]=None):\n",
    "        '''\n",
    "        Input shape:\n",
    "            q               : [bs x n_heads x max_q_len x d_k]\n",
    "            k               : [bs x n_heads x d_k x seq_len]\n",
    "            v               : [bs x n_heads x seq_len x d_v]\n",
    "            prev            : [bs x n_heads x q_len x seq_len]\n",
    "            key_padding_mask: [bs x seq_len]\n",
    "            attn_mask       : [1 x seq_len x seq_len]\n",
    "        Output shape:\n",
    "            output:  [bs x n_heads x q_len x d_v]\n",
    "            attn   : [bs x n_heads x q_len x seq_len]\n",
    "            scores : [bs x n_heads x q_len x seq_len] (returned only when res_attention is True)\n",
    "        '''\n",
    "\n",
    "        # Similarity scores for all pairs of positions, scaled\n",
    "        scores = torch.matmul(q, k) * self.scale                      # [bs x n_heads x max_q_len x q_len]\n",
    "\n",
    "        # Residual attention: add the pre-softmax scores of the previous layer\n",
    "        if prev is not None:\n",
    "            scores = scores + prev\n",
    "\n",
    "        # Attention mask: boolean masks blank out positions, other masks are added to the scores\n",
    "        if attn_mask is not None:\n",
    "            if attn_mask.dtype == torch.bool:\n",
    "                scores.masked_fill_(attn_mask, -np.inf)\n",
    "            else:\n",
    "                scores += attn_mask\n",
    "\n",
    "        # Key padding mask, broadcast over heads and query positions\n",
    "        if key_padding_mask is not None:\n",
    "            scores.masked_fill_(key_padding_mask.unsqueeze(1).unsqueeze(2), -np.inf)\n",
    "\n",
    "        # Normalize into attention weights and apply dropout\n",
    "        weights = self.attn_dropout(F.softmax(scores, dim=-1))        # [bs x n_heads x max_q_len x q_len]\n",
    "\n",
    "        # Weighted sum of the values\n",
    "        output = torch.matmul(weights, v)                             # [bs x n_heads x max_q_len x d_v]\n",
    "\n",
    "        if self.res_attention:\n",
    "            return output, weights, scores\n",
    "        return output, weights"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "class PatchTST(BaseWindows):\n",
    "    \"\"\" PatchTST\n",
    "\n",
    "    The PatchTST model is an efficient Transformer-based model for multivariate time series forecasting.\n",
    "\n",
    "    It is based on two key components:\n",
    "    - segmentation of time series into windows (patches) which are served as input tokens to Transformer\n",
    "    - channel-independence, where each channel contains a single univariate time series.\n",
    "\n",
    "    **Parameters:**<br>\n",
    "    `h`: int, Forecast horizon. <br>\n",
    "    `input_size`: int, autoregressive inputs size, y=[1,2,3,4] input_size=2 -> y_[t-2:t]=[1,2].<br>\n",
    "    `stat_exog_list`: str list, static exogenous columns. Not supported yet, must be None.<br>\n",
    "    `hist_exog_list`: str list, historic exogenous columns. Not supported yet, must be None.<br>\n",
    "    `futr_exog_list`: str list, future exogenous columns. Not supported yet, must be None.<br>\n",
    "    `exclude_insample_y`: bool=False, the model skips the autoregressive features y[t-input_size:t] if True.<br>\n",
    "    `encoder_layers`: int=3, number of layers for encoder.<br>\n",
    "    `n_heads`: int=16, number of attention heads.<br>\n",
    "    `hidden_size`: int=128, units of embeddings and encoders.<br>\n",
    "    `linear_hidden_size`: int=256, units of linear layer.<br>\n",
    "    `dropout`: float=0.2, dropout rate for residual connection.<br>\n",
    "    `fc_dropout`: float=0.2, dropout rate for linear layer.<br>\n",
    "    `head_dropout`: float=0.0, dropout rate for Flatten head layer.<br>\n",
    "    `attn_dropout`: float=0.0, dropout rate for attention layer.<br>\n",
    "    `patch_len`: int=16, length of patch.<br>\n",
    "    `stride`: int=8, stride of patch.<br>\n",
    "    `revin`: bool=True, bool to use RevIn.<br>\n",
    "    `revin_affine`: bool=False, bool to use affine in RevIn.<br>\n",
    "    `revin_subtract_last`: bool=True, bool to use subtract last in RevIn.<br>\n",
    "    `activation`: str='gelu', activation from ['gelu','relu'].<br>\n",
    "    `res_attention`: bool=True, bool to use residual attention.<br>\n",
    "    `batch_normalization`: bool=False, bool to use batch normalization (forwarded to the backbone as `pre_norm`).<br>\n",
    "    `learn_pos_embed`: bool=True, bool to learn positional embedding.<br>\n",
    "    `loss`: PyTorch module, instantiated train loss class from [losses collection](https://nixtla.github.io/neuralforecast/losses.pytorch.html).<br>\n",
    "    `valid_loss`: PyTorch module=`loss`, instantiated valid loss class from [losses collection](https://nixtla.github.io/neuralforecast/losses.pytorch.html).<br>\n",
    "    `max_steps`: int=5000, maximum number of training steps.<br>\n",
    "    `learning_rate`: float=1e-4, Learning rate between (0, 1).<br>\n",
    "    `num_lr_decays`: int=-1, Number of learning rate decays, evenly distributed across max_steps.<br>\n",
    "    `early_stop_patience_steps`: int=-1, Number of validation iterations before early stopping.<br>\n",
    "    `val_check_steps`: int=100, Number of training steps between every validation loss check.<br>\n",
    "    `batch_size`: int=32, number of different series in each batch.<br>\n",
    "    `valid_batch_size`: int=None, number of different series in each validation and test batch, if None uses batch_size.<br>\n",
    "    `windows_batch_size`: int=1024, number of windows to sample in each training batch, default uses all.<br>\n",
    "    `inference_windows_batch_size`: int=1024, number of windows to sample in each inference batch.<br>\n",
    "    `start_padding_enabled`: bool=False, if True, the model will pad the time series with zeros at the beginning, by input size.<br>\n",
    "    `step_size`: int=1, step size between each window of temporal data.<br>\n",
    "    `scaler_type`: str='identity', type of scaler for temporal inputs normalization see [temporal scalers](https://nixtla.github.io/neuralforecast/common.scalers.html).<br>\n",
    "    `random_seed`: int=1, random_seed for pytorch initializer and numpy generators.<br>\n",
    "    `num_workers_loader`: int=0, workers to be used by `TimeSeriesDataLoader`.<br>\n",
    "    `drop_last_loader`: bool=False, if True `TimeSeriesDataLoader` drops last non-full batch.<br>\n",
    "    `alias`: str, optional,  Custom name of the model.<br>\n",
    "    `**trainer_kwargs`: int,  keyword trainer arguments inherited from [PyTorch Lightning's trainer](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.trainer.trainer.Trainer.html?highlight=trainer).<br>\n",
    "\n",
    "    **Raises:**<br>\n",
    "    `Exception`: if `stat_exog_list`, `futr_exog_list` or `hist_exog_list` is not None.<br>\n",
    "\n",
    "    **References:**<br>\n",
    "    - [Nie, Y., Nguyen, N. H., Sinthong, P., & Kalagnanam, J. (2022). \"A Time Series is Worth 64 Words: Long-term Forecasting with Transformers\"](https://arxiv.org/pdf/2211.14730.pdf)\n",
    "    \"\"\"\n",
    "    # Class attributes\n",
    "    SAMPLING_TYPE = 'windows'\n",
    "\n",
    "    def __init__(self,\n",
    "                 h,\n",
    "                 input_size,\n",
    "                 stat_exog_list = None,\n",
    "                 hist_exog_list = None,\n",
    "                 futr_exog_list = None,\n",
    "                 exclude_insample_y = False,\n",
    "                 encoder_layers: int = 3,\n",
    "                 n_heads: int = 16,\n",
    "                 hidden_size: int = 128,\n",
    "                 linear_hidden_size: int = 256,\n",
    "                 dropout: float = 0.2,\n",
    "                 fc_dropout: float = 0.2,\n",
    "                 head_dropout: float = 0.0,\n",
    "                 attn_dropout: float = 0.,\n",
    "                 patch_len: int = 16,\n",
    "                 stride: int = 8,\n",
    "                 revin: bool = True,\n",
    "                 revin_affine: bool = False,\n",
    "                 revin_subtract_last: bool = True,\n",
    "                 activation: str = \"gelu\",\n",
    "                 res_attention: bool = True,\n",
    "                 batch_normalization: bool = False,\n",
    "                 learn_pos_embed: bool = True,\n",
    "                 loss = MAE(),\n",
    "                 valid_loss = None,\n",
    "                 max_steps: int = 5000,\n",
    "                 learning_rate: float = 1e-4,\n",
    "                 num_lr_decays: int = -1,\n",
    "                 early_stop_patience_steps: int =-1,\n",
    "                 val_check_steps: int = 100,\n",
    "                 batch_size: int = 32,\n",
    "                 valid_batch_size: Optional[int] = None,\n",
    "                 windows_batch_size = 1024,\n",
    "                 inference_windows_batch_size: int = 1024,\n",
    "                 start_padding_enabled = False,\n",
    "                 step_size: int = 1,\n",
    "                 scaler_type: str = 'identity',\n",
    "                 random_seed: int = 1,\n",
    "                 num_workers_loader: int = 0,\n",
    "                 drop_last_loader: bool = False,\n",
    "                 **trainer_kwargs):\n",
    "        # Training, data-loading and loss settings are handled by the base class.\n",
    "        super(PatchTST, self).__init__(h=h,\n",
    "                                       input_size=input_size,\n",
    "                                       hist_exog_list=hist_exog_list,\n",
    "                                       stat_exog_list=stat_exog_list,\n",
    "                                       futr_exog_list = futr_exog_list,\n",
    "                                       exclude_insample_y = exclude_insample_y,\n",
    "                                       loss=loss,\n",
    "                                       valid_loss=valid_loss,\n",
    "                                       max_steps=max_steps,\n",
    "                                       learning_rate=learning_rate,\n",
    "                                       num_lr_decays=num_lr_decays,\n",
    "                                       early_stop_patience_steps=early_stop_patience_steps,\n",
    "                                       val_check_steps=val_check_steps,\n",
    "                                       batch_size=batch_size,\n",
    "                                       valid_batch_size=valid_batch_size,\n",
    "                                       windows_batch_size=windows_batch_size,\n",
    "                                       inference_windows_batch_size=inference_windows_batch_size,\n",
    "                                       start_padding_enabled=start_padding_enabled,\n",
    "                                       step_size=step_size,\n",
    "                                       scaler_type=scaler_type,\n",
    "                                       num_workers_loader=num_workers_loader,\n",
    "                                       drop_last_loader=drop_last_loader,\n",
    "                                       random_seed=random_seed,\n",
    "                                       **trainer_kwargs)\n",
    "        # Asserts: `forward` only consumes `insample_y`, so every kind of\n",
    "        # exogenous input is rejected explicitly instead of being silently ignored.\n",
    "        if stat_exog_list is not None:\n",
    "            raise Exception(\"PatchTST does not yet support static exogenous variables\")\n",
    "        if futr_exog_list is not None:\n",
    "            raise Exception(\"PatchTST does not yet support future exogenous variables\")\n",
    "        if hist_exog_list is not None:\n",
    "            raise Exception(\"PatchTST does not yet support historical exogenous variables\")\n",
    "\n",
    "        # Number of outputs per forecast step, as requested by the loss.\n",
    "        c_out = self.loss.outputsize_multiplier\n",
    "\n",
    "        # Fixed hyperparameters\n",
    "        c_in = 1                  # Always univariate\n",
    "        padding_patch='end'       # Padding at the end\n",
    "        pretrain_head = False     # No pretrained head\n",
    "        norm = 'BatchNorm'        # Use BatchNorm (if batch_normalization is True)\n",
    "        pe = 'zeros'              # Initial zeros for positional encoding\n",
    "        d_k = None                # Key dimension\n",
    "        d_v = None                # Value dimension\n",
    "        store_attn = False        # Store attention weights\n",
    "        head_type = 'flatten'     # Head type\n",
    "        individual = False        # Separate heads for each time series\n",
    "        max_seq_len = 1024        # Not used\n",
    "        key_padding_mask = 'auto' # Not used\n",
    "        padding_var = None        # Not used\n",
    "        attn_mask = None          # Not used\n",
    "\n",
    "        # Note the renames between this constructor and the backbone:\n",
    "        # encoder_layers -> n_layers, activation -> act, batch_normalization -> pre_norm,\n",
    "        # learn_pos_embed -> learn_pe, revin_affine -> affine, revin_subtract_last -> subtract_last.\n",
    "        self.model = PatchTST_backbone(c_in=c_in, c_out=c_out, input_size=input_size, h=h, patch_len=patch_len, stride=stride,\n",
    "                                max_seq_len=max_seq_len, n_layers=encoder_layers, hidden_size=hidden_size,\n",
    "                                n_heads=n_heads, d_k=d_k, d_v=d_v, linear_hidden_size=linear_hidden_size, norm=norm, attn_dropout=attn_dropout,\n",
    "                                dropout=dropout, act=activation, key_padding_mask=key_padding_mask, padding_var=padding_var,\n",
    "                                attn_mask=attn_mask, res_attention=res_attention, pre_norm=batch_normalization, store_attn=store_attn,\n",
    "                                pe=pe, learn_pe=learn_pos_embed, fc_dropout=fc_dropout, head_dropout=head_dropout, padding_patch = padding_patch,\n",
    "                                pretrain_head=pretrain_head, head_type=head_type, individual=individual, revin=revin, affine=revin_affine,\n",
    "                                subtract_last=revin_subtract_last)\n",
    "\n",
    "\n",
    "    def forward(self, windows_batch):  # x: [batch, input_size]\n",
    "        \"\"\"Forecast `h` steps for every window in the batch.\n",
    "\n",
    "        **Parameters:**<br>\n",
    "        `windows_batch`: dict, only its `insample_y` entry is read; all other entries are ignored.<br>\n",
    "\n",
    "        **Returns:**<br>\n",
    "        `forecast`: result of `self.loss.domain_map` applied to the backbone output reshaped to [Batch, h, c_out].<br>\n",
    "        \"\"\"\n",
    "        # Parse windows_batch: exogenous inputs are rejected in __init__, so only\n",
    "        # the autoregressive inputs are used.\n",
    "        insample_y = windows_batch['insample_y']\n",
    "\n",
    "        # Add dimension for channel\n",
    "        x = insample_y.unsqueeze(-1) # [Ws,L,1]\n",
    "\n",
    "        x = x.permute(0,2,1)    # x: [Batch, 1, input_size]\n",
    "        x = self.model(x)\n",
    "        x = x.reshape(x.shape[0], self.h, -1) # x: [Batch, h, c_out]\n",
    "\n",
    "        # Domain map\n",
    "        forecast = self.loss.domain_map(x)\n",
    "\n",
    "        return forecast"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "show_doc(PatchTST)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "show_doc(PatchTST.fit, name='PatchTST.fit')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "show_doc(PatchTST.predict, name='PatchTST.predict')"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Usage example"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| eval: false\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import pytorch_lightning as pl\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from neuralforecast import NeuralForecast\n",
    "from neuralforecast.models import PatchTST\n",
    "from neuralforecast.losses.pytorch import MQLoss, DistributionLoss\n",
    "from neuralforecast.tsdataset import TimeSeriesDataset\n",
    "from neuralforecast.utils import AirPassengers, AirPassengersPanel, AirPassengersStatic, augment_calendar_df\n",
    "\n",
    "AirPassengersPanel, calendar_cols = augment_calendar_df(df=AirPassengersPanel, freq='M')\n",
    "\n",
    "# Hold out the last 12 timestamps of every series for testing.\n",
    "Y_train_df = AirPassengersPanel[AirPassengersPanel.ds<AirPassengersPanel['ds'].values[-12]] # 132 train\n",
    "Y_test_df = AirPassengersPanel[AirPassengersPanel.ds>=AirPassengersPanel['ds'].values[-12]].reset_index(drop=True) # 12 test\n",
    "\n",
    "# PatchTST rejects exogenous variables, so no *_exog_list argument is passed.\n",
    "model = PatchTST(h=12,\n",
    "                 input_size=104,\n",
    "                 patch_len=24,\n",
    "                 stride=24,\n",
    "                 revin=False,\n",
    "                 hidden_size=16,\n",
    "                 n_heads=4,\n",
    "                 scaler_type='robust',\n",
    "                 loss=DistributionLoss(distribution='StudentT', level=[80, 90]),\n",
    "                 #loss=MAE(),\n",
    "                 learning_rate=1e-3,\n",
    "                 max_steps=500,\n",
    "                 val_check_steps=50,\n",
    "                 early_stop_patience_steps=2)\n",
    "\n",
    "nf = NeuralForecast(\n",
    "    models=[model],\n",
    "    freq='M'\n",
    ")\n",
    "nf.fit(df=Y_train_df, static_df=AirPassengersStatic, val_size=12)\n",
    "forecasts = nf.predict(futr_df=Y_test_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| eval: false\n",
    "def plot_airline_forecast(plot_df, unique_id, is_distribution_output, horizon=12, level=90):\n",
    "    \"\"\"Plot the observed series and the PatchTST forecast of one series.\n",
    "\n",
    "    **Parameters:**<br>\n",
    "    `plot_df`: pandas DataFrame with `unique_id`, `ds`, `y` and the PatchTST forecast columns.<br>\n",
    "    `unique_id`: str, identifier of the series to plot.<br>\n",
    "    `is_distribution_output`: bool, if True plots `PatchTST-median` and the `level` prediction interval, otherwise the `PatchTST` point forecast.<br>\n",
    "    `horizon`: int=12, number of trailing rows over which the prediction interval is shaded.<br>\n",
    "    `level`: int=90, level of the interval, selects the `PatchTST-lo-{level}` / `PatchTST-hi-{level}` columns.<br>\n",
    "\n",
    "    **Returns:**<br>\n",
    "    `ax`: matplotlib Axes holding the figure.<br>\n",
    "    \"\"\"\n",
    "    series_df = plot_df[plot_df.unique_id==unique_id].drop('unique_id', axis=1)\n",
    "\n",
    "    # A fresh figure per call, so consecutive calls never draw on the same axes.\n",
    "    fig, ax = plt.subplots()\n",
    "    ax.plot(series_df['ds'], series_df['y'], c='black', label='True')\n",
    "    if is_distribution_output:\n",
    "        ax.plot(series_df['ds'], series_df['PatchTST-median'], c='blue', label='median')\n",
    "        ax.fill_between(x=series_df['ds'][-horizon:],\n",
    "                        y1=series_df[f'PatchTST-lo-{level}'][-horizon:].values,\n",
    "                        y2=series_df[f'PatchTST-hi-{level}'][-horizon:].values,\n",
    "                        alpha=0.4, label=f'level {level}')\n",
    "    else:\n",
    "        ax.plot(series_df['ds'], series_df['PatchTST'], c='blue', label='Forecast')\n",
    "    ax.set(title=unique_id, xlabel='ds', ylabel='y')\n",
    "    ax.grid()\n",
    "    ax.legend()\n",
    "    return ax\n",
    "\n",
    "# Train rows followed by test rows joined column-wise with their forecasts.\n",
    "# The column-wise concat relies on `forecasts` and `Y_test_df` sharing row order.\n",
    "Y_hat_df = forecasts.reset_index(drop=False).drop(columns=['unique_id','ds'])\n",
    "test_with_forecasts_df = pd.concat([Y_test_df, Y_hat_df], axis=1)\n",
    "plot_df = pd.concat([Y_train_df, test_with_forecasts_df])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| eval: false\n",
    "plot_airline_forecast(plot_df, 'Airline1', model.loss.is_distribution_output);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| eval: false\n",
    "plot_airline_forecast(plot_df, 'Airline2', model.loss.is_distribution_output);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "python3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
